{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C76CDbIJi2DY"
      },
      "source": [
        "# How to use CLIP Zero-Shot on your own classification dataset\n",
        "\n",
        "This notebook provides an example of how to benchmark CLIP's zero shot classification performance on your own classification dataset.\n",
        "\n",
        "[CLIP](https://openai.com/blog/clip/) is a new zero shot image classifier released by OpenAI that has been trained on 400 million text/image pairs across the web. CLIP uses these learnings to make predictions based on a flexible span of possible classification categories.\n",
        "\n",
        "CLIP is zero shot, which means **no training is required**. \n",
        "\n",
        "Try it out on your own task here!\n",
        "\n",
        "Be sure to experiment with various text prompts to unlock the richness of CLIP's pretraining procedure.\n",
        "\n",
        "\n",
        "![Roboflow Wordmark](https://i.imgur.com/dcLNMhV.png)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lOF3Feb7jrnu"
      },
      "source": [
        "# Download and Install CLIP Dependencies"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "FyUIWjzOi23X",
        "outputId": "62393789-10e6-4ef3-f7b4-10ae459506cf"
      },
      "source": [
        "#installing some dependencies, CLIP was release in PyTorch\n",
        "import subprocess\n",
        "\n",
        "CUDA_version = [s for s in subprocess.check_output([\"nvcc\", \"--version\"]).decode(\"UTF-8\").split(\", \") if s.startswith(\"release\")][0].split(\" \")[-1]\n",
        "print(\"CUDA version:\", CUDA_version)\n",
        "\n",
        "if CUDA_version == \"10.0\":\n",
        "    torch_version_suffix = \"+cu100\"\n",
        "elif CUDA_version == \"10.1\":\n",
        "    torch_version_suffix = \"+cu101\"\n",
        "elif CUDA_version == \"10.2\":\n",
        "    torch_version_suffix = \"\"\n",
        "else:\n",
        "    torch_version_suffix = \"+cu110\"\n",
        "\n",
        "!pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "import os\n",
        "\n",
        "print(\"Torch version:\", torch.__version__)\n",
        "os.kill(os.getpid(), 9)\n",
        "#Your notebook process will restart after these installs"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "CUDA version: 11.1\n",
            "Looking in links: https://download.pytorch.org/whl/torch_stable.html\n",
            "Collecting torch==1.7.1+cu110\n",
            "  Downloading https://download.pytorch.org/whl/cu110/torch-1.7.1%2Bcu110-cp37-cp37m-linux_x86_64.whl (1156.8 MB)\n",
            "\u001b[K     |███████████████████████         | 834.1 MB 1.6 MB/s eta 0:03:18tcmalloc: large alloc 1147494400 bytes == 0x563157d14000 @  0x7f18440fd615 0x5631543e84cc 0x5631544c847a 0x5631543eb2ed 0x5631544dce1d 0x56315445ee99 0x5631544599ee 0x5631543ecbda 0x56315445ed00 0x5631544599ee 0x5631543ecbda 0x56315445b737 0x5631544ddc66 0x56315445adaf 0x5631544ddc66 0x56315445adaf 0x5631544ddc66 0x56315445adaf 0x5631543ed039 0x563154430409 0x5631543ebc52 0x56315445ec25 0x5631544599ee 0x5631543ecbda 0x56315445b737 0x5631544599ee 0x5631543ecbda 0x56315445a915 0x5631543ecafa 0x56315445ac0d 0x5631544599ee\n",
            "\u001b[K     |█████████████████████████████▏  | 1055.7 MB 1.5 MB/s eta 0:01:07tcmalloc: large alloc 1434370048 bytes == 0x56319c36a000 @  0x7f18440fd615 0x5631543e84cc 0x5631544c847a 0x5631543eb2ed 0x5631544dce1d 0x56315445ee99 0x5631544599ee 0x5631543ecbda 0x56315445ed00 0x5631544599ee 0x5631543ecbda 0x56315445b737 0x5631544ddc66 0x56315445adaf 0x5631544ddc66 0x56315445adaf 0x5631544ddc66 0x56315445adaf 0x5631543ed039 0x563154430409 0x5631543ebc52 0x56315445ec25 0x5631544599ee 0x5631543ecbda 0x56315445b737 0x5631544599ee 0x5631543ecbda 0x56315445a915 0x5631543ecafa 0x56315445ac0d 0x5631544599ee\n",
            "\u001b[K     |████████████████████████████████| 1156.7 MB 1.3 MB/s eta 0:00:01tcmalloc: large alloc 1445945344 bytes == 0x563236a80000 @  0x7f18440fd615 0x5631543e84cc 0x5631544c847a 0x5631543eb2ed 0x5631544dce1d 0x56315445ee99 0x5631544599ee 0x5631543ecbda 0x56315445ac0d 0x5631544599ee 0x5631543ecbda 0x56315445ac0d 0x5631544599ee 0x5631543ecbda 0x56315445ac0d 0x5631544599ee 0x5631543ecbda 0x56315445ac0d 0x5631544599ee 0x5631543ecbda 0x56315445ac0d 0x5631543ecafa 0x56315445ac0d 0x5631544599ee 0x5631543ecbda 0x56315445b737 0x5631544599ee 0x5631543ecbda 0x56315445b737 0x5631544599ee 0x5631543ed271\n",
            "\u001b[K     |████████████████████████████████| 1156.8 MB 14 kB/s \n",
            "\u001b[?25hCollecting torchvision==0.8.2+cu110\n",
            "  Downloading https://download.pytorch.org/whl/cu110/torchvision-0.8.2%2Bcu110-cp37-cp37m-linux_x86_64.whl (12.9 MB)\n",
            "\u001b[K     |████████████████████████████████| 12.9 MB 26.8 MB/s \n",
            "\u001b[?25hCollecting ftfy\n",
            "  Downloading ftfy-6.0.3.tar.gz (64 kB)\n",
            "\u001b[K     |████████████████████████████████| 64 kB 1.8 MB/s \n",
            "\u001b[?25hRequirement already satisfied: regex in /usr/local/lib/python3.7/dist-packages (2019.12.20)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch==1.7.1+cu110) (3.7.4.3)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torch==1.7.1+cu110) (1.19.5)\n",
            "Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.7/dist-packages (from torchvision==0.8.2+cu110) (7.1.2)\n",
            "Requirement already satisfied: wcwidth in /usr/local/lib/python3.7/dist-packages (from ftfy) (0.2.5)\n",
            "Building wheels for collected packages: ftfy\n",
            "  Building wheel for ftfy (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for ftfy: filename=ftfy-6.0.3-py3-none-any.whl size=41933 sha256=035e5c52778aabac6248e9b483e8a4e00cf589b31921197ff717ef82e27055b0\n",
            "  Stored in directory: /root/.cache/pip/wheels/19/f5/38/273eb3b5e76dfd850619312f693716ac4518b498f5ffb6f56d\n",
            "Successfully built ftfy\n",
            "Installing collected packages: torch, torchvision, ftfy\n",
            "  Attempting uninstall: torch\n",
            "    Found existing installation: torch 1.9.0+cu111\n",
            "    Uninstalling torch-1.9.0+cu111:\n",
            "      Successfully uninstalled torch-1.9.0+cu111\n",
            "  Attempting uninstall: torchvision\n",
            "    Found existing installation: torchvision 0.10.0+cu111\n",
            "    Uninstalling torchvision-0.10.0+cu111:\n",
            "      Successfully uninstalled torchvision-0.10.0+cu111\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "torchtext 0.10.0 requires torch==1.9.0, but you have torch 1.7.1+cu110 which is incompatible.\u001b[0m\n",
            "Successfully installed ftfy-6.0.3 torch-1.7.1+cu110 torchvision-0.8.2+cu110\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "cqN0UVpssA7J",
        "outputId": "3098922c-d340-44c8-f7ba-76a76ffd4969"
      },
      "source": [
        "#clone the CLIP repository\n",
        "!git clone https://github.com/openai/CLIP.git\n",
        "%cd CLIP"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Cloning into 'CLIP'...\n",
            "remote: Enumerating objects: 164, done.\u001b[K\n",
            "remote: Counting objects: 100% (73/73), done.\u001b[K\n",
            "remote: Compressing objects: 100% (49/49), done.\u001b[K\n",
            "remote: Total 164 (delta 31), reused 49 (delta 19), pack-reused 91\u001b[K\n",
            "Receiving objects: 100% (164/164), 8.87 MiB | 27.51 MiB/s, done.\n",
            "Resolving deltas: 100% (71/71), done.\n",
            "/content/CLIP\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hwCkgL73rHE0"
      },
      "source": [
        "# Download Classification Data or Object Detection Data\n",
        "\n",
        "We will download the [public flowers classification dataset](https://public.roboflow.com/classification/flowers_classification) from Roboflow. The data will come out as folders broken into train/valid/test splits, with a separate folder for each class label.\n",
        "\n",
        "You can easily download your own dataset from Roboflow in this format, too.\n",
        "\n",
        "We made a conversion from object detection to CLIP text prompts in Roboflow, too, if you want to try that out.\n",
        "\n",
        "\n",
        "To get your data into Roboflow, follow the [Getting Started Guide](https://blog.roboflow.ai/getting-started-with-roboflow/)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "JWy5SOmV5tLK",
        "outputId": "b2d18619-86ff-48ec-d339-0f649a0284a0"
      },
      "source": [
        "#follow the link below to get your download code from Roboflow\n",
        "!pip install -q roboflow\n",
        "from roboflow import Roboflow\n",
        "rf = Roboflow(model_format=\"clip\", notebook=\"roboflow-clip\")"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[K     |████████████████████████████████| 178 kB 9.2 MB/s \n",
            "\u001b[K     |████████████████████████████████| 1.1 MB 40.4 MB/s \n",
            "\u001b[K     |████████████████████████████████| 138 kB 51.3 MB/s \n",
            "\u001b[K     |████████████████████████████████| 636 kB 40.4 MB/s \n",
            "\u001b[K     |████████████████████████████████| 62 kB 809 kB/s \n",
            "\u001b[?25h  Building wheel for roboflow (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for wget (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "torchtext 0.10.0 requires torch==1.9.0, but you have torch 1.7.1+cu110 which is incompatible.\n",
            "google-colab 1.0.0 requires requests~=2.23.0, but you have requests 2.26.0 which is incompatible.\n",
            "datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.\n",
            "albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.\u001b[0m\n",
            "upload and label your dataset, and get an API KEY here: https://app.roboflow.com/?model=yolov5&ref=roboflow-yolov5\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "AHHnCdsKrCVF",
        "outputId": "513c2529-fe14-4d2c-ba4f-38443136eed2"
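      },
      "source": [
        "#download classification data: uncomment the lines below and fill in your API key, project and version (the cells below read `dataset`)\n",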
        "# from roboflow import Roboflow\n",
        "# rf = Roboflow(api_key=\"YOUR_API_KEY\")\n",
        "# project = rf.workspace().project(\"YOUR_PROJECT\")\n",
        "# dataset = project.version(\"YOUR_VERSION\").download(\"clip\")"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "loading Roboflow workspace...\n",
            "loading Roboflow project...\n",
            "Downloading Dataset Version Zip in Rock-Paper-Scissors-1 to clip: 100% [12121758 / 12121758] bytes\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Extracting Dataset Version Zip to Rock-Paper-Scissors-1 in clip:: 100%|██████████| 2941/2941 [00:02<00:00, 1034.21it/s]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "id": "sFsApb_r65Us",
        "outputId": "bf35519d-38d0-468f-9be5-9b54aa0b885f"
      },
      "source": [
        "dataset.location"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'/content/CLIP/Rock-Paper-Scissors-1'"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-m7FYTSjCN46",
        "outputId": "2dc044ad-1b15-4d6c-c97c-c74846f02168"
      },
      "source": [
        "import os\n",
        "#our the classes and images we want to test are stored in folders in the test set\n",
        "class_names = os.listdir(dataset.location + '/test/')\n",
        "class_names.remove('_tokenization.txt')\n",
        "class_names"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['scissors', 'rock', 'paper']"
            ]
          },
          "metadata": {},
          "execution_count": 7
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QELzUB7pnr-h",
        "outputId": "471da341-412b-4f1f-9161-0ba7fcaebd82"
      },
      "source": [
        "#we auto generate some example tokenizations in Roboflow but you should edit this file to try out your own prompts\n",
        "#CLIP gets a lot better with the right prompting!\n",
        "#be sure the tokenizations are in the same order as your class_names above!\n",
        "%cat {dataset.location}/test/_tokenization.txt"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "An example picture from the Rock Paper Scissors dataset depicting a paper\n",
            "An example picture from the Rock Paper Scissors dataset depicting a rock\n",
            "An example picture from the Rock Paper Scissors dataset depicting a scissors"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "jPQxIGxXn8bR",
        "outputId": "59f78f4e-2222-48bd-d947-37a4fdfe9203"
      },
      "source": [
        "#edit your prompts as you see fit here, be sure the classes are in teh same order as above\n",
        "%%writefile {dataset.location}/test/_tokenization.txt\n",
        "The paper sign in rock paper scissors\n",
        "The rock sign in rock paper scissors\n",
        "The scissors sign in rock paper scissors"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Overwriting /content/CLIP/Rock-Paper-Scissors-1/test/_tokenization.txt\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gpXaqPH3oSyO"
      },
      "source": [
        "candidate_captions = []\n",
        "with open(dataset.location + '/test/_tokenization.txt') as f:\n",
        "    candidate_captions = f.read().splitlines()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vMZJy1SduxiE"
      },
      "source": [
        "# Run CLIP inference on your classification dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yAi7cvucnFPr",
        "outputId": "edb02089-53e3-4629-eb2f-670ebc1304bc"
      },
      "source": [
        "import torch\n",
        "import clip\n",
        "from PIL import Image\n",
        "import glob\n",
        "\n",
        "def argmax(iterable):\n",
        "    return max(enumerate(iterable), key=lambda x: x[1])[0]\n",
        "\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "model, transform = clip.load(\"ViT-B/32\", device=device)\n",
        "\n",
        "correct = []\n",
        "\n",
        "#define our target classificaitons, you can should experiment with these strings of text as you see fit, though, make sure they are in the same order as your class names above\n",
        "text = clip.tokenize(candidate_captions).to(device)\n",
        "\n",
        "for cls in class_names:\n",
        "    class_correct = []\n",
        "    test_imgs = glob.glob(dataset.location + '/test/' + cls + '/*.jpg')\n",
        "    for img in test_imgs:\n",
        "        #print(img)\n",
        "        image = transform(Image.open(img)).unsqueeze(0).to(device)\n",
        "        with torch.no_grad():\n",
        "            image_features = model.encode_image(image)\n",
        "            text_features = model.encode_text(text)\n",
        "            \n",
        "            logits_per_image, logits_per_text = model(image, text)\n",
        "            probs = logits_per_image.softmax(dim=-1).cpu().numpy()\n",
        "\n",
        "            pred = class_names[argmax(list(probs)[0])]\n",
        "            #print(pred)\n",
        "            if pred == cls:\n",
        "                correct.append(1)\n",
        "                class_correct.append(1)\n",
        "            else:\n",
        "                correct.append(0)\n",
        "                class_correct.append(0)\n",
        "    \n",
        "    print('accuracy on class ' + cls + ' is :' + str(sum(class_correct)/len(class_correct)))\n",
        "print('accuracy on all is : ' + str(sum(correct)/len(correct)))\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "accuracy on class scissors is :0.5454545454545454\n",
            "accuracy on class rock is :0.18181818181818182\n",
            "accuracy on class paper is :0.0\n",
            "accuracy on all is : 0.24242424242424243\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4yrqnKyUofrb"
      },
      "source": [
        "#Hope you enjoyed!\n",
        "#As always, happy inferencing\n",
        "#Roboflow"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "SslAXBoOvF8L"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}